import torch
import einops
import numpy as np
import torch.nn as nn
from torch import Tensor
from functools import partial
from torchdiffeq import odeint
from contextlib import nullcontext

# `accelerate` is optional: when present, the UNet is built with empty
# (meta) weights and tensors are materialized directly on the target
# device, which avoids a full CPU copy of the model.
try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    is_accelerate_available = True
except ImportError:
    # Only swallow a missing package — a bare `except:` would also hide
    # KeyboardInterrupt/SystemExit and real bugs inside accelerate.
    is_accelerate_available = False

from .unet.openaimodel import UNetModel

def exists(val):
    """Return True when `val` holds a value, i.e. is not None."""
    return val is not None


class DepthFM(nn.Module):
    """Flow-matching monocular depth estimator operating in a VAE latent space.

    Images are encoded to latents, an ODE is integrated from the
    (optionally noised) image latent to the depth latent, and the result
    is decoded back to pixel space and per-sample min-max normalized.
    The UNet and VAE are shuttled between `device` and `offload_device`
    around each use to limit peak memory.
    """

    def __init__(self, vae, ckpt_path: str, device, offload_device, dtype):
        super().__init__()
        self.vae = vae
        self.scale_factor = 0.18215  # SD-style latent scaling factor
        self.device = device
        self.offload_device = offload_device

        # Load on CPU so checkpoints saved from a GPU process still load
        # on CPU-only machines; weights are placed on `device` below (or
        # in forward()). NOTE(review): the checkpoint stores non-tensor
        # objects (hparams dict, text embedding) — on torch>=2.6 the new
        # `weights_only=True` default may require revisiting this call.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.noising_step = state_dict['noising_step']
        self.empty_text_embed = state_dict['empty_text_embedding']
        # With accelerate, build the UNet on meta (empty) weights and
        # materialize each tensor straight onto the target device/dtype.
        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            self.model = UNetModel(**state_dict['ldm_hparams'])
        if is_accelerate_available:
            for key, value in state_dict['state_dict'].items():
                set_module_tensor_to_device(self.model, key, device=device, dtype=dtype, value=value)
        else:
            self.model.load_state_dict(state_dict['state_dict'])

    def ode_fn(self, t: Tensor, x: Tensor, **kwargs):
        """Velocity field for the ODE solver.

        Broadcasts a scalar solver time to one entry per batch element
        before calling the UNet.
        """
        if t.numel() == 1:
            t = t.expand(x.size(0))
        return self.model(x=x, t=t, **kwargs)

    def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs):
        """
        ODE solving from z0 (ims) to z1 (depth).

        Returns the final state, or the full trajectory of
        n_intermediates + 2 states when n_intermediates > 0.
        """
        # Fixed-step Euler; rtol/atol are ignored by fixed-step methods,
        # step_size is what actually controls the number of steps.
        ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps))

        # t specifies which intermediate times the solver should return
        # e.g. t = [0, 0.5, 1] means return the solution at t=0, t=0.5 and t=1
        t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype)

        # bake conditioning information into the ODE function
        ode_fn = partial(self.ode_fn, **kwargs)

        ode_results = odeint(ode_fn, z, t, **ode_kwargs)

        if n_intermediates > 0:
            return ode_results
        return ode_results[-1]

    def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """
        Args:
            ims: Tensor of shape (b, 3, h, w) in range [-1, 1]
        Returns:
            depth: Tensor of shape (b, 1, h, w) in range [0, 1]
        """
        if ensemble_size > 1:
            assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1"
            ims = ims.repeat(ensemble_size, 1, 1, 1)

        bs, dev = ims.shape[0], ims.device

        # Encode on the compute device, then offload the VAE again.
        self.vae.first_stage_model = self.vae.first_stage_model.to(self.device)
        ims_z = self.encode(ims, sample_posterior=False)
        self.vae.first_stage_model = self.vae.first_stage_model.to(self.offload_device)

        # Cross-attention conditioning: the stored empty-text embedding,
        # repeated for every batch element.
        conditioning = torch.tensor(self.empty_text_embed).to(dev).repeat(bs, 1, 1)
        context = ims_z

        x_source = ims_z

        # Optionally start the ODE from a diffused version of the latent.
        if self.noising_step > 0:
            x_source = q_sample(x_source, self.noising_step)

        # solve ODE
        self.model.to(self.device)
        depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning)
        self.model.to(self.offload_device)

        self.vae.first_stage_model = self.vae.first_stage_model.to(self.device)
        depth = self.decode(depth_z)
        self.vae.first_stage_model = self.vae.first_stage_model.to(self.offload_device)

        # Collapse the decoded channels to a single depth channel.
        depth = depth.mean(dim=1, keepdim=True)

        # Average the ensemble back down to batch size 1.
        if ensemble_size > 1:
            depth = depth.mean(dim=0, keepdim=True)

        # normalize depth maps to range [0, 1]
        depth = per_sample_min_max_normalization(depth.exp())

        return depth

    @torch.no_grad()
    def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1):
        """ Inference method for DepthFM. """
        return self.forward(ims, num_steps, ensemble_size)

    @torch.no_grad()
    def encode(self, x: Tensor, sample_posterior: bool = True):
        """Encode images to scaled VAE latents.

        NOTE(review): `sample_posterior` is currently ignored — the VAE
        encode output is used as-is; confirm whether posterior sampling
        was intended here.
        """
        z = self.vae.first_stage_model.encode(x)
        z = z * self.scale_factor
        return z

    @torch.no_grad()
    def decode(self, z: Tensor):
        """Decode latents back to image space, undoing the scale factor."""
        z = 1.0 / self.scale_factor * z
        return self.vae.first_stage_model.decode(z)


def sigmoid(x):
    """Logistic function 1 / (1 + e^{-x}); elementwise for numpy inputs."""
    denom = 1 + np.exp(-x)
    return 1 / denom


def cosine_log_snr(t, eps=0.00001):
    """
    Returns log Signal-to-Noise ratio for time step t and image size 64.
    eps: avoids taking log(0) (division by zero) at t = 0
    """
    half_angle = (np.pi * t) / 2
    return -2 * np.log(np.tan(half_angle) + eps)


def cosine_alpha_bar(t):
    """Alpha-bar of the cosine noise schedule: sigmoid of the log-SNR at t."""
    log_snr = cosine_log_snr(t)
    return sigmoid(log_snr)


def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000):
    """
    Diffuse the data for a given number of diffusion steps. In other
    words sample from q(x_t | x_0) under the cosine schedule.

    Args:
        x_start: clean sample x_0
        t: integer diffusion step in [0, n_diffusion_timesteps]
        noise: optional pre-drawn Gaussian noise with the shape of x_start
        n_diffusion_timesteps: total number of steps the schedule spans
    Returns:
        Noised sample x_t with the same shape/dtype/device as x_start.
    """
    dev = x_start.device
    dtype = x_start.dtype

    if noise is None:
        noise = torch.randn_like(x_start)

    # Scalar alpha_bar from the (numpy) cosine schedule, materialized
    # directly on the target device/dtype instead of via two .to() hops.
    alpha_bar_t = torch.tensor(cosine_alpha_bar(t / n_diffusion_timesteps), device=dev, dtype=dtype)

    return torch.sqrt(alpha_bar_t) * x_start + torch.sqrt(1 - alpha_bar_t) * noise


def per_sample_min_max_normalization(x):
    """Normalize each sample in a batch independently
    with min-max normalization to [0, 1].

    Args:
        x: Tensor of shape (b, ...); min/max are taken over all
           non-batch dimensions of each sample.
    Returns:
        Tensor of the same shape with every sample scaled to [0, 1].
        (A constant sample yields NaN, as in the original einops version.)
    """
    # Plain torch replaces the einops flatten/reduce: flatten all
    # non-batch dims, reduce with keepdim so broadcasting lines up.
    flat = x.flatten(start_dim=1)
    min_val = flat.amin(dim=1, keepdim=True)
    max_val = flat.amax(dim=1, keepdim=True)
    flat = (flat - min_val) / (max_val - min_val)
    return flat.view_as(x)
